from contextlib import contextmanager
import numpy as np
import torch

from core.components.base_model import BaseModel
from core.modules.losses import L2Loss
from core.modules.subnetworks import Encoder, Decoder, Predictor
from core.utils.general_utils import AttrDict, ParamDict
from core.modules.layers import LayerBuilderParams
from core.utils.vis_utils import make_image_strip


class OracleModel(BaseModel):
    """Image auto-encoder whose embedding is additionally regressed onto ground-truth state.

    The encoder embeds the first frame of ``inputs.images``. A decoder rebuilds
    that frame from the embedding (fed a detached copy when
    ``detach_reconstruction`` is set). A ``Predictor`` head with ``spatial=False``
    maps the embedding to a ``state_dim``-sized state estimate.
    """

    def __init__(self, params, logger):
        super().__init__(logger)
        hp = self._default_hparams()
        self._hp = hp
        # Caller-supplied params are applied on top of the class defaults.
        hp.overwrite(params)
        hp.builder = LayerBuilderParams(hp.use_convs, hp.normalization)

        self.build_network()

    def _default_hparams(self):
        """Return the parent's default hparams extended with this model's defaults."""
        # Architecture / behaviour switches.
        model_flags = {
            'use_skips': False,
            'skips_stride': 2,
            'add_weighted_pixel_copy': False,  # if True, adds pixel copying stream for decoder
            'pixel_shift_decoder': False,
            'use_convs': True,
            'use_custom_convs': False,
            'detach_reconstruction': True,  # decoder receives enc.detach() in forward()
            'n_cond_frames': 1,
            'use_seg_mask': False,
            'detach_seg_mask': True,
            'seg_dec_activation': None,
            'n_class': 1,
            'use_random_rep': False,
            'use_obj_labels': False,
            'normalization': 'none',
        }
        # Network size.
        size_flags = {
            # NOTE(review): -1 looks like a placeholder that params must replace
            # before Predictor is built — confirm.
            'state_dim': -1,
            'img_sz': 32,
            'input_nc': 3,
            'ngf': 8,
            'nz_enc': 32,
            'nz_mid': 32,
            'n_processing_layers': 3,
            'n_pixel_sources': 1,
        }
        own_defaults = ParamDict({**model_flags, **size_flags})

        # Layer this model's params over the parent's defaults.
        merged = super()._default_hparams()
        merged.overwrite(own_defaults)
        return merged

    def build_network(self):
        """Instantiate the encoder, the image decoder and the state-prediction head."""
        hp = self._hp
        self.encoder = Encoder(hp)
        self.decoder = Decoder(hp)
        # Head mapping the nz_enc-sized embedding to a state_dim-sized output.
        self.state_predictor = Predictor(
            hp,
            input_size=hp.nz_enc,
            output_size=hp.state_dim,
            spatial=False,
        )

    def forward(self, inputs):
        """Run the training-time forward pass.

        Returns an ``AttrDict`` with the following keys:

        * ``pred`` / ``rec_input``: the encoder embedding.
        * ``output_imgs``: the decoder's ``.images`` with a singleton dim inserted at axis 1.
        * ``s_hat``: the state prediction.
        """
        # presumably images is (batch, time, channels, H, W); only index 0 of
        # axis 1 is encoded — TODO confirm layout.
        first_frame = inputs.images[:, 0]
        embedding = self.encoder(first_frame)

        output = AttrDict()
        output.update(pred=embedding, rec_input=embedding)

        # Optionally stop reconstruction gradients from reaching the encoder.
        if self._hp.detach_reconstruction:
            decoder_input = embedding.detach()
        else:
            decoder_input = embedding
        decoded = self.decoder(decoder_input)
        output.output_imgs = decoded.images.unsqueeze(1)

        output.s_hat = self.state_predictor(embedding)
        return output

    def loss(self, model_output, inputs):
        """Compute the reconstruction and state losses, plus their total.

        ``rec_loss`` is an L2 loss between ``output_imgs`` and ``inputs.images[:, :1]``.
        ``state_pred_loss`` is an L2 loss between ``s_hat`` and ``inputs.states[:, 0]``.
        ``total`` comes from ``self._compute_total_loss``.
        """
        reconstruction = L2Loss(1.)(model_output.output_imgs, inputs.images[:, :1])
        state_error = L2Loss(1.)(model_output.s_hat, inputs.states[:, 0])

        losses = AttrDict(rec_loss=reconstruction, state_pred_loss=state_error)
        losses.total = self._compute_total_loss(losses)
        return losses

    def log_outputs(self, model_output, inputs, losses, step, log_images, phase):
        """Log losses and, when ``log_images`` is set, an input/reconstruction image strip."""
        super()._log_losses(losses, step, log_images, phase)
        if not log_images:
            return

        # Keep only the trailing input_nc // n_frames channels of each image.
        # NOTE(review): n_frames is not among this class's defaults; presumably
        # provided by BaseModel's defaults or by params — confirm.
        n_shown = int(self._hp.input_nc // self._hp.n_frames)
        channels = slice(-n_shown, None)
        strip = make_image_strip([
            inputs.images[:, 0, channels],
            model_output.output_imgs[:, 0, channels],
        ])
        self._logger.log_images(strip[None], 'generation', step, phase)

    def forward_encoder(self, inputs):
        """Return the encoder's output for ``inputs``."""
        return self.encoder(inputs)

    @property
    def resolution(self):
        """Image size taken from the ``img_sz`` hparam."""
        return self._hp.img_sz

    @contextmanager
    def val_mode(self):
        """Context manager that enters and exits without changing any model state."""
        yield


